from ast import If
from cProfile import label
from fileinput import filename
from genericpath import isdir, isfile
from operator import truediv
from tokenize import Triple

from turtle import update
from scipy.io import savemat, loadmat

import numpy as np
import cvxpy as cvx

from data_utils import get_dataset, vstack

import matplotlib.pyplot as plt
import os
import argparse

import warnings
import time
import pickle

from utils import get_needed_dirs, train_model, test_model

# Silence all warnings raised by the imported libraries for the whole run.
warnings.filterwarnings("ignore")

# ---------------------------------------------------------------------------
# Command-line interface. The parsed namespace is stored in ``args`` and is
# passed on to main() and to the project helpers imported above.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()

# --- dataset / victim model ---
parser.add_argument('--dataset', default='mnist_17',help="options: 2d_toy, mnist_17, dogfish,cifar10_05,cifar10_14,mnist_38,mnist_69")
parser.add_argument('--model_type',default='svm',help='victim model type: SVM or rlogistic regression')
parser.add_argument('--weight_decay',default=0.09, type=float, help='weight decay for regularizers')
# NOTE(review): the help text below looks copy-pasted from a verbosity flag; confirm what --use_custom_svm does.
parser.add_argument('--use_custom_svm',action="store_true", help='print the detailed information')
parser.add_argument('--rand_seed',default=1234, type=int, help='seed for random number generator')
parser.add_argument('--epsilon',default=0.03, type=float, help='poisoning ratio')
parser.add_argument('--obtained_epsilon',default=0.1, type=float, help='poisoning ratio that we have results for, useful for continuously running the attack')
parser.add_argument('--print_every',default=50, type=int, help='print the attack statistics every n iterations')

# --- toy dataset generation / search space ---
parser.add_argument('--flip_y',default=0.01, type=float, help='fraction of label noise')
parser.add_argument('--class_sep',default=2.0, type=float, help='class separability')
parser.add_argument('--addi_search_space',default=2.0, type=float, help='additional search space for the attack, added to min,max value of clean points')
parser.add_argument('--n_samples',default=100000, type=int, help='number of total samples')
parser.add_argument('--n_features',default=1, type=int, help='number of features for the toy dataset')
parser.add_argument('--num_sampled_points',default=50, type=int, help='how many samples to generate for grid search')
parser.add_argument('--train_frac',default=0.5, type=float, help='fraction of train data in all samples')

# attack params for algorithmic influence maximization method
parser.add_argument('--init_mode', default='current',help="options: random [randomly generate init point], current [randomly select from existing data]")
# parser.add_argument('--baseline_influence',action="store_true", help='the baseline influence attack')
parser.add_argument('--num_restart',default=10, type=int, help='how many restarts for optimizing the lr influence maximization')
parser.add_argument('--lr',default=0.01, type=float, help='learning rate when performing gradient ascend')
parser.add_argument('--num_opt_steps',default=100, type=int, help='optimization steps for lr influence maximization')
parser.add_argument('--optimizer', default='adam',help="supported optimizer options: 'gd', 'adagrad', 'adam'")
# The values listed here are the ones compare_different_attacks() actually checks for.
parser.add_argument('--attack', default='all',help="attack options: 'baseline_if', 'kkt', 'min_max', 'mta', 'all' ")
parser.add_argument('--approx_type', default='exact',help="hvp computation options: 'lissa', 'exact', 'identity' ")
# parser.add_argument('--target_model_type', default='None',help="types of target models generated: 'exhaust', 'error_highL','error_lowL','None'")
parser.add_argument('--target_error',default=0.05, type=float, help='error rate for the target model')

# --- target model handling and attack variants ---
parser.add_argument('--generate_target',action="store_true", help='generate the target models first!')
parser.add_argument('--non_greedy_check',action="store_true", help='check whether the current attack is non-greedy!')
parser.add_argument('--low_poison_budget',action="store_true", help='low poison regime and show enumarated results!')
parser.add_argument('--use_train',action="store_true", help='test data will be leveraged in the attack process')
parser.add_argument('--original',action="store_true", help='we will leverage the original target model generation process, not improved one')
parser.add_argument('--grad_calib_mode',default="orig", help='grad_calibration options: orig,conf,norm')
parser.add_argument('--minimize',action="store_true", help='minimize influence w.r.t. flipped label')

parser.add_argument('--if_rep_num',default=1, type=int, help='repeat n times to give better estimate on the induced model weight')
parser.add_argument('--check_model_update_every',default=0, type=int, help='how frequently to check model update, 0 means no model update!')
parser.add_argument('--test_single',action="store_true", help='test single poison optimality!')
parser.add_argument('--load_mat_target',action="store_true", help='load the target model from the mat file!')
# NOTE(review): the help text below duplicates --target_error; in this file C is only
# used as the "C-{}" component of the min-max result file name. Confirm its meaning.
parser.add_argument('--C',default=0.0, type=float, help='error rate for the target model')
parser.add_argument('--lr_min_C',default=0.25, type=float, help='minimum loss threshold for logistic regression')

# --- plotting / diagnostics ---
parser.add_argument('--plot_figs',action="store_true", help='plot figures')
parser.add_argument('--extreme_start',action="store_true", help='generate the extreme points for as the start point')
parser.add_argument('--check_dist',action="store_true", help='check the distance of generated poisons')

parser.add_argument('--plot_epsilon',default=0.3, type=float, help='poisoning ratio where we only care about specific smaller ranges')
parser.add_argument('--check_transfer',action="store_true", help='check the transferability of different models')
parser.add_argument('--combine_train_test',action="store_true", help='use the combined error rates to check the performance')
# main() calls subsample_test on the loaded data and then exits when this flag is set.
parser.add_argument('--subsample',action="store_true", help='run subsample_test on the loaded train/test data and exit')
parser.add_argument('--use_slab',action="store_true", help='use oracale slab defense')
parser.add_argument('--use_sphere',action="store_true", help='use sphere defese')

args = parser.parse_args()

def main(args):
    """Load the dataset, report the clean model and summarise stored attack results.

    For every epoch setting (several for ``cifar10_89``, a single ``0``
    otherwise) the data is loaded with ``get_dataset``, a clean model is
    trained and evaluated verbosely, and ``compare_different_attacks`` is
    called with a poison budget of ``int(args.epsilon * n_train)`` points.

    ``args`` is mutated in place: ``weight_decay``, ``epoch``, ``lr``,
    ``num_opt_steps``, ``num_restart`` and ``combine_train_test`` may be
    overwritten with dataset-specific values.

    With ``--subsample`` the function calls ``subsample_test`` and exits the
    process instead of running the comparison.

    Raises:
        ValueError: if ``--subsample`` is used with a dataset that has no
            sampling fractions configured.
    """
    # Imported locally: the file-level imports do not provide ``sys`` and the
    # --subsample early exit below needs it.
    import sys

    if args.dataset == 'imdb':
        args.weight_decay = 0.01

    np.random.seed(args.rand_seed)
    if args.dataset == 'cifar10_89':
        epochs = [-1,1,50,90,100,120]
    else:
        epochs = [0]

    for epoch in epochs:
        args.epoch = epoch
        X_train,Y_train,X_test,Y_test,x_lims = get_dataset(args,epoch)
        print("--- Train/Test Data Size --- ")
        print(X_train.shape,Y_train.shape,X_test.shape,Y_test.shape)
        print("Search Space: Min {}, Max {}".format(x_lims[0],x_lims[1]))

        if args.subsample:
            # NOTE(review): the empty-string dataset name below looks like a lost
            # value; it is kept unchanged because the intended name is unknown.
            if args.dataset == '':
                sample_frac_test=0.1
                sample_frac_train=0.2
            elif args.dataset == 'enron':
                sample_frac_test=0.03
                sample_frac_train=0.03
            else:
                # Previously this fell through and failed with an unbound-variable error.
                raise ValueError(
                    "--subsample has no sampling fractions configured for dataset '{}'".format(args.dataset))
            # NOTE(review): subsample_test is neither defined nor imported in the
            # visible part of this file; confirm where it is supposed to come from.
            subsample_test(X_train,X_test,Y_train,Y_test,args,sample_frac_test,sample_frac_train)
            sys.exit()

        if args.dataset == 'enron' and args.model_type == 'lr':
            args.weight_decay = 0.01

        print("--- Performance of Clean Models --- ")
        clean_model = train_model(X_train,Y_train,args)
        # Called for its verbose report only; the returned metrics are not needed here.
        test_model(X_train,Y_train,X_train,Y_train,X_test,Y_test,clean_model,args,verbose=True)
        clean_data = [X_train,Y_train,X_test,Y_test]

        num_iter = int(args.epsilon*X_train.shape[0])
        use_test = not args.use_train

        # fix the default params for each dataset so that easy to run with few params
        if args.dataset == 'mnist_17':
            args.lr = 0.1
            args.num_opt_steps = 200
            args.num_restart = 10
        elif args.dataset == 'dogfish':
            args.lr = 0.1
            args.num_opt_steps = 3000
            args.num_restart = 10

        if args.dataset == "dogfish" and not args.use_train:
            # dogfish dataset heavily overfits to the test data used, so combine train and test data to give better picture.
            print("combining train and test data!")
            args.combine_train_test = True
        compare_different_attacks(args,x_lims,num_iter,clean_data=clean_data,use_test=use_test,combine_train_test=args.combine_train_test)

def _load_pickle(path):
    """Return the object unpickled from ``path``; the file is always closed."""
    # NOTE: unpickling can execute arbitrary code, so ``path`` must point at a
    # result file produced by this project, never at untrusted input.
    with open(path, "rb") as file_to_read:
        return pickle.load(file_to_read)


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path``, closing the file even if dumping fails."""
    with open(path, "wb") as file_to_write:
        pickle.dump(obj, file_to_write)


def _get_target_errors(args, use_train_or_test, tar_model_dir):
    """Return the target-model error rates to iterate over for this run.

    Raises:
        ValueError: if ``args.dataset`` has no configured list of target errors.
    """
    if args.load_mat_target:
        return [0.0]

    dataset = args.dataset
    if dataset == 'mnist_17':
        if args.model_type == 'svm':
            if args.epsilon <= 0.1:
                return [0.03,0.04,0.05,0.06,0.07,0.09,0.11]
            return [0.1,0.15,0.17,0.2,0.23,0.25,0.3,0.35,0.4,0.45,0.5]
        return [0.03,0.04,0.05,0.06,0.07,0.09,0.11,0.1,0.15,0.17,0.2,0.23,0.25,0.3,0.35,0.4,0.45,0.5]
    if dataset == 'dogfish':
        if use_train_or_test == 'use_test':
            return [0.05,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7,0.8]
        return [0.03,0.05,0.07,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7,0.8]
    if dataset == 'mnist_69':
        return [0.03,0.04,0.05,0.06,0.07,0.09,0.11,0.12,0.15,0.17,0.2,0.23,0.25,0.3,0.35,0.4,0.45,0.5,0.55]
    if dataset == 'mnist_38':
        return [0.04,0.05,0.06,0.07,0.09,0.11,0.12,0.15,0.17,0.2,0.23,0.25,0.3,0.35,0.4,0.45,0.5,0.55]
    if dataset == 'mnist_49':
        if args.model_type == 'lr':
            return [0.06,0.07,0.09,0.11,0.12,0.15,0.17,0.2,0.23,0.25,0.3,0.35,0.4,0.45,0.5,0.55]
        return [0.05,0.06,0.07,0.09,0.11,0.12,0.15,0.17,0.2,0.23,0.25,0.3,0.35,0.4,0.45,0.5,0.55]
    if dataset == 'cifar10_05':
        return [0.03,0.05,0.07,0.45,0.5,0.55]
    if dataset == 'cifar10_14':
        return [0.03,0.05,0.07,0.45, 0.55]
    if dataset == 'enron':
        return [0.05,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7]
    if dataset == 'adult':
        target_errors = [0.23,0.25,0.3,0.33,0.35,0.4,0.45,0.5,0.55,0.6]
        if args.model_type == 'lr':
            target_errors.append(0.65)
        return target_errors
    if dataset == 'cifar10_89':
        # because cifar10_89 has many datasets, we directly load the target errors.
        return np.load('{}/generated_errors.npy'.format(tar_model_dir))

    # Previously this case only printed the message and then crashed with a
    # NameError on the undefined list; fail with the same message instead.
    raise ValueError("please provide a valid dataset name!")


def compare_different_attacks(args,x_lims,num_iter,clean_data,use_test=True,combine_train_test=False):
    """Collect stored results of the poisoning attacks and print a test-error summary.

    The function does not run any attack itself. It looks for result files
    written earlier (baseline influence ``.npz`` files and pickled MTA / KKT /
    min-max results), keeps the best result per poisoning ratio over all
    target models, retrains the victim model on clean + poison data where the
    poison points are available, pickles the best KKT / min-max poisons into
    ``<result_dir>/best_results`` and prints the summary. Missing result files
    are reported and skipped.

    Args:
        args: parsed command-line namespace; ``args.attack`` selects which
            attacks are processed ('baseline_if', 'kkt', 'min_max', 'mta' or
            'all').
        x_lims: pair ``(x_min, x_max)`` bounding the poison search space; the
            loaded baseline poisons are asserted to lie inside it.
        num_iter: length of the per-iteration MTA bookkeeping arrays.
        clean_data: sequence ``[X_train, Y_train, X_test, Y_test]``.
        use_test: accepted for interface compatibility; not used in the body.
        combine_train_test: when true, errors are additionally evaluated on
            the stacked train + test data.

    Returns:
        None. Results are printed and written to disk.

    Raises:
        ValueError: if ``args.dataset`` has no configured target errors.
    """
    X_train,Y_train,X_test,Y_test = clean_data
    x_min,x_max = x_lims
    num_train = X_train.shape[0]

    # prepare the clean models. The metrics are not used further; both calls are
    # kept so that any side effects of the helpers stay exactly as before.
    curr_model = train_model(X_train,Y_train,args)
    test_model(X_train,Y_train,X_train,Y_train,X_test,Y_test,curr_model,args,verbose=False)

    use_train_or_test = 'use_train' if args.use_train else 'use_test'

    # get the relevant dirs (the figure dir returned second is not needed here)
    result_dir, _fig_dir_name, tar_model_dir, tar_model_dir_mat = get_needed_dirs(args,X_train)

    if args.use_sphere or args.use_slab:
        # currently, only consider both defenses, can distinguish further if needed
        assert args.use_slab and args.use_sphere
        result_dir = '{}/oracle_defense'.format(result_dir)

    check_eps = [args.epsilon]
    # index of the last poison iteration for each checked ratio
    check_eps_to_int = np.array([int(np.round(a*num_train))-1 for a in check_eps])

    # ------------------------------------------------------------------
    # Baseline influence attack: results are read from .npz files.
    # ------------------------------------------------------------------
    baseline_IF_test_loss = np.zeros(len(check_eps_to_int))
    baseline_IF_train_loss = np.zeros(len(check_eps_to_int))
    baseline_IF_test_error = np.zeros(len(check_eps_to_int))
    baseline_IF_train_error = np.zeros(len(check_eps_to_int))

    if args.attack in ['baseline_if','all']:
        # now, we directly load the results generated and copied from the server
        baseline_IF_dir = 'influence_attack/output/{}/influence_data/{}'.format(args.dataset,args.model_type)
        os.makedirs(baseline_IF_dir, exist_ok=True)
        if args.model_type == 'svm':
            model_type = 'smooth_hinge'
            temp = 0.001
        else:
            model_type = 'lr'
            temp = 0.0

        if_test_preds = {}

        for i, eps in enumerate(check_eps):
            # e.g. smooth_hinge_mnist_17_no_defense_step-0.1_iter-10000_t-0.001_eps-0.3_wd-0.09_rs-1234_use_test_0.2892_attack.npz
            result_fname = '{}/{}_{}_no_defense_step-0.1_iter-10000_t-{}_eps-{}_wd-{}_rs-{}_{}_attack.npz'.format(
                baseline_IF_dir,model_type,args.dataset,temp,eps,args.weight_decay,args.rand_seed,use_train_or_test)
            if not os.path.isfile(result_fname):
                print("first generate the baseline IF results!")
                print(result_fname)
                continue

            attack_results_npz = np.load(result_fname)
            # the poison points are stored after the clean training points
            X_poison = attack_results_npz['poisoned_X_train'][num_train:,:]
            Y_poison = attack_results_npz['Y_train'][num_train:]
            assert X_poison.shape[0] == int(np.round(eps*num_train))

            poison_min, poison_max = np.amin(X_poison,axis=0), np.amax(X_poison,axis=0)
            assert (poison_min >= x_min).all()
            assert (poison_max <= x_max).all()

            if args.model_type == 'lr':
                # stored labels use 0 for the negative class; the models here use -1
                Y_poison[Y_poison==0] = -1

            X_total = vstack(X_train,X_poison)
            Y_total = vstack(Y_train,Y_poison)
            p_model = train_model(X_total,Y_total,args)
            # since IF based attacks tend to overfit, we will report both the test and train accuracy after poisoning
            if args.dataset == 'dogfish' and combine_train_test:
                X_test_new = vstack(X_train,X_test)
                Y_test_new = vstack(Y_train,Y_test)
            else:
                X_test_new = X_test
                Y_test_new = Y_test

            total_train_loss, poison_train_loss, clean_train_loss, clean_test_loss, total_train_acc, poison_train_acc, clean_train_acc,\
                clean_test_acc = test_model(X_total,Y_total,X_train,Y_train,X_test_new,Y_test_new,p_model,args,verbose=False)

            baseline_IF_train_error[i] = 1-clean_train_acc
            baseline_IF_test_error[i] = 1-clean_test_acc
            baseline_IF_train_loss[i] = clean_train_loss
            baseline_IF_test_loss[i] = clean_test_loss

            if_test_preds[eps] = p_model.predict(X_test)

    best_result_dir = '{}/best_results'.format(result_dir)
    os.makedirs(best_result_dir, exist_ok=True)

    # mta, kkt, min_max attacks require target models
    tar_model = train_model(X_train,Y_train,args)

    target_errors = _get_target_errors(args, use_train_or_test, tar_model_dir)

    # record the best result from batch of target models tested
    # we also record the lowest curr and target loss observed
    best_mta_train_loss = (-1)*np.ones(num_iter)
    best_mta_train_error = (-1)*np.ones(num_iter)
    best_mta_test_loss = (-1)*np.ones(num_iter)
    best_mta_test_error = (-1)*np.ones(num_iter)
    best_mta_target_error_loc = np.zeros(num_iter)
    best_mta_curr_loss = np.zeros(num_iter)
    best_mta_tar_loss = np.zeros(num_iter)

    # store the best poison attacks; -1 marks "no result seen yet"
    best_mta_poisons = {eps: {'test_error': -1} for eps in check_eps}
    best_kkt_poisons = {eps: {'test_error': -1} for eps in check_eps}
    best_min_max_poisons = {eps: {'test_error': -1} for eps in check_eps}

    mta_test_preds = {}
    kkt_test_preds = {}
    min_max_test_preds = {}

    min_max_skipped_tar_errs = []
    kkt_skipped_tar_errs = []
    mta_skipped_tar_errs = []

    # only read when combine_train_test is set
    mta_combined_errs = {eps: -1 for eps in check_eps}
    kkt_combined_errs = {}

    if args.load_mat_target or args.original:
        tar_model_type = 'original'
    else:
        tar_model_type = 'improved'

    for target_error in target_errors:
        if args.load_mat_target:
            if not args.use_train:
                tar_model_fname = '{}/{}_thetas_with_bias_exact_decay_09_v3_prune.mat'.format(tar_model_dir_mat,args.dataset)
            else:
                tar_model_fname = '{}/{}_thetas_with_bias_exact_decay_09_use_train_v3_prune.mat'.format(tar_model_dir_mat,args.dataset)
        else:
            if not args.original:
                tar_model_fname = '{}/improved_best_theta_whole_err-{}'.format(tar_model_dir,target_error)
            else:
                tar_model_fname = '{}/orig_best_theta_whole_err-{}'.format(tar_model_dir,target_error)

        # load the target model related stuffs
        if os.path.isfile(tar_model_fname):
            f = _load_pickle(tar_model_fname)
            best_target_theta = f['best_theta']
            best_target_bias = f['best_bias']
            tar_model_train_loss = f['best_train_loss']
            tar_model_train_err = f['best_train_error']
            tar_model_test_loss = f['best_test_loss']
            tar_model_test_err = f['best_test_error']
            print("---- Loaded Target Model Info----")
            print("Train Error: {:.5f}, Train Loss: {:.5f}, Test Error: {:.5f}, Test Loss: {:.5f}".format(tar_model_train_err,tar_model_train_loss,\
                tar_model_test_err,tar_model_test_loss))
            thetas = [best_target_theta]
            biases = [best_target_bias]
            test_errors = [tar_model_test_err]
        else:
            print("Target Model File {} does not exist!".format(tar_model_fname))
            thetas = [tar_model.coef_.reshape(-1)]
            biases = [tar_model.intercept_]
            test_errors = [target_error]
        # NOTE(review): ``test_errors`` was referenced by the --load_mat_target
        # branches below but never defined (NameError). It is now kept parallel to
        # ``thetas``, using the loaded target model's test error when available.
        # Confirm this matches the error value used in the mat-target result file names.

        for iiii, (theta, bias) in enumerate(zip(thetas, biases)):
            tar_model.coef_= np.array([theta])
            if args.load_mat_target:
                tar_model.intercept_ = np.array([bias])
            else:
                tar_model.intercept_ = bias

            # ----------------------------------------------------------
            # MTA attack
            # ----------------------------------------------------------
            if args.attack in ['mta','all']:
                if args.load_mat_target:
                    target_error = test_errors[iiii]
                    mta_result_dir = '{}/mta_results/mat_target'.format(result_dir)
                    result_fname = '{}/mta_{}_of_{}_{}_target_model_err-{:.5f}_{}_results'.format(mta_result_dir,num_iter,num_train,tar_model_type,target_error,use_train_or_test)
                else:
                    mta_result_dir = '{}/mta_results'.format(result_dir)
                    result_fname = '{}/mta_{}_of_{}_{}_target_model_err-{}_{}_results'.format(mta_result_dir,num_iter,num_train,tar_model_type,target_error,use_train_or_test)
                os.makedirs(mta_result_dir, exist_ok=True)

                if not os.path.isfile(result_fname):
                    print("[[Warning]]!!!! mta with {} model of target error {:.5f} does not exist, skipping!".format(tar_model_type,target_error))
                    print(result_fname)
                    mta_skipped_tar_errs.append(target_error)
                else:
                    print("directly loading results of mta of target error {:.5f}".format(target_error))
                    results_dict = _load_pickle(result_fname)
                    mta_results = results_dict['mta_results']
                    mta_curr_and_tar_losses = results_dict['mta_loss_on_curr_and_tar']

                    # process the mta results to ensure that we produce best result using exhautive search on target models
                    tmp_mta_train_loss = mta_results[-4]
                    tmp_mta_train_error = mta_results[-3]
                    tmp_mta_test_loss = mta_results[-2]
                    tmp_mta_test_error = mta_results[-1]

                    # we are interested in finding target models that maximize test error
                    update_ids = tmp_mta_test_error > best_mta_test_error
                    best_mta_train_loss[update_ids] = tmp_mta_train_loss[update_ids]
                    best_mta_train_error[update_ids] = tmp_mta_train_error[update_ids]
                    best_mta_test_loss[update_ids] = tmp_mta_test_loss[update_ids]
                    best_mta_test_error[update_ids] = tmp_mta_test_error[update_ids]
                    best_mta_target_error_loc[update_ids] = target_error
                    best_mta_curr_loss[update_ids] = mta_curr_and_tar_losses[0][update_ids]
                    best_mta_tar_loss[update_ids] = mta_curr_and_tar_losses[1][update_ids]

                    # save the best attack for the mta attack
                    for eps in check_eps:
                        num_poison = int(eps*num_train)
                        if tmp_mta_test_error[num_poison-1] > best_mta_poisons[eps]['test_error']:
                            best_mta_poisons[eps]['X_poison'] = mta_results[0]
                            best_mta_poisons[eps]['Y_poison'] = mta_results[1]
                            best_mta_poisons[eps]['test_error'] = tmp_mta_test_error[num_poison-1]
                        # save the poisoned points and make predictions subsequently
                        X_poison = mta_results[0][:num_poison]
                        Y_poison = mta_results[1][:num_poison]

                        X_total = vstack(X_train,X_poison)
                        Y_total = vstack(Y_train,Y_poison)
                        p_model = train_model(X_total,Y_total,args)
                        mta_test_preds[eps] = p_model.predict(X_test)
                        if combine_train_test:
                            X_test_new = vstack(X_train,X_test)
                            Y_test_new = vstack(Y_train,Y_test)

                            total_train_loss, poison_train_loss, clean_train_loss, clean_test_loss, total_train_acc, poison_train_acc, clean_train_acc,\
                                    clean_test_acc = test_model(X_train,Y_train,X_train,Y_train,X_test_new,Y_test_new,p_model,args,verbose=False)
                            if mta_combined_errs[eps] < 1-clean_test_acc:
                                mta_combined_errs[eps] = 1-clean_test_acc

            for kkt_epsilon in check_eps:
                kkt_num_iter = int(kkt_epsilon * num_train)

                # ------------------------------------------------------
                # KKT attack
                # ------------------------------------------------------
                if args.attack in ['kkt','all']:
                    if args.load_mat_target:
                        target_error = test_errors[iiii]
                        kkt_result_dir = '{}/kkt_results/mat_target'.format(result_dir)
                        # the file lives in the KKT result dir (an undefined
                        # ``max_loss_result_dir`` was referenced here before)
                        result_fname = '{}/kkt_{}_of_{}_{}_target_model_err-{:.5f}_{}_results'.format(kkt_result_dir,kkt_num_iter,num_train,\
                            tar_model_type,target_error,use_train_or_test)
                    else:
                        kkt_result_dir = '{}/kkt_results'.format(result_dir)
                        result_fname = '{}/kkt_{}_of_{}_{}_target_model_err-{}_{}_results'.format(kkt_result_dir,\
                        kkt_num_iter,num_train,tar_model_type,target_error,use_train_or_test)

                    print(result_fname)
                    if not os.path.isfile(result_fname):
                        print("[warning]!!!!KKT attack of target error {:.5f} is missing!".format(target_error))
                        print(result_fname)
                        kkt_skipped_tar_errs.append(target_error)
                    else:
                        print("directly loading results of KKT of target error {:.5f}, epsilon {}!".format(target_error,kkt_epsilon))
                        results_dict = _load_pickle(result_fname)
                        kkt_results = results_dict['kkt_results']
                        best_kkt_results = results_dict['best_kkt_results']

                        # the last four entries are train loss, train error, test loss, test error
                        tmp_kkt_train_loss = best_kkt_results[-4]
                        tmp_kkt_train_error = best_kkt_results[-3]
                        tmp_kkt_test_loss = best_kkt_results[-2]
                        tmp_kkt_test_error = best_kkt_results[-1]

                        best_kkt = best_kkt_poisons[kkt_epsilon]
                        if tmp_kkt_test_error > best_kkt['test_error']:
                            best_kkt['X_poison'] = kkt_results[0]
                            best_kkt['Y_poison'] = kkt_results[1]
                            best_kkt['test_error'] = tmp_kkt_test_error
                            best_kkt['test_loss'] = tmp_kkt_test_loss
                            best_kkt['train_error'] = tmp_kkt_train_error
                            best_kkt['train_loss'] = tmp_kkt_train_loss
                            best_kkt['best_tar_error'] = target_error
                            # save the poisoned points and make predictions subsequently
                            X_poison = kkt_results[0]
                            Y_poison = kkt_results[1]
                            X_total = vstack(X_train,X_poison)
                            Y_total = vstack(Y_train,Y_poison)
                            p_model = train_model(X_total,Y_total,args)
                            kkt_test_preds[kkt_epsilon] = p_model.predict(X_test)
                            if combine_train_test and args.dataset == 'dogfish':
                                X_test_new = vstack(X_train,X_test)
                                Y_test_new = vstack(Y_train,Y_test)
                                # NOTE(review): this refit repeats the one above on the same
                                # data; kept in case train_model is not deterministic.
                                p_model = train_model(X_total,Y_total,args)
                                total_train_loss, poison_train_loss, clean_train_loss, clean_test_loss, total_train_acc, poison_train_acc, clean_train_acc,\
                                    clean_test_acc = test_model(X_total,Y_total,X_train,Y_train,X_test_new,Y_test_new,p_model,args,verbose=False)
                                kkt_combined_errs[kkt_epsilon] = 1-clean_test_acc

                # ------------------------------------------------------
                # Min-max attack. Guarded by args.attack like the other
                # attacks and like the summary below; the original had a
                # stray ``else`` bound to the for loop instead of a guard.
                # ------------------------------------------------------
                if args.attack in ['min_max','all']:
                    num_sgd_steps = 1
                    min_max_lr_matlab = 0.03
                    if args.dataset in ['mnist_17','mnist_38','mnist_69','cifar10_05']:
                        burn_frac = max(0.33, 0.02/kkt_epsilon-1)
                    elif args.dataset == 'dogfish':
                        burn_frac = max(1.0, 0.10/kkt_epsilon-1)
                    elif args.dataset == 'enron':
                        burn_frac = max(1.0, 0.10/kkt_epsilon-1)
                    else:
                        burn_frac = 0.0
                    C = args.C

                    min_max_dir = '{}/min_max_results'.format(result_dir)
                    result_fname = '{}/min_max_{}_of_{}_{}_target_model_err-{:.5f}_SGDSteps-{}_lr-{}_burnin-{}_C-{}_{}_results'.format(min_max_dir,kkt_num_iter,num_train,\
                    tar_model_type,target_error,num_sgd_steps,min_max_lr_matlab,burn_frac,C,use_train_or_test)

                    print(result_fname)
                    if not os.path.isfile(result_fname):
                        print("[warning]!!!! Min Max attack of target error {:.5f} is missing!".format(target_error))
                        print(result_fname)
                        min_max_skipped_tar_errs.append(target_error)
                    else:
                        print("directly loading results of Orig Min Max of target error {:.5f}, epsilon {}!".format(target_error,kkt_epsilon))
                        results_dict = _load_pickle(result_fname)
                        min_max_results = results_dict['min_max_results']

                        tmp_min_max_train_loss = min_max_results[-4]
                        tmp_min_max_train_error = min_max_results[-3]
                        tmp_min_max_test_loss = min_max_results[-2]
                        tmp_min_max_test_error = min_max_results[-1]

                        # the final entry of each sequence is used as the result for this ratio
                        best_min_max = best_min_max_poisons[kkt_epsilon]
                        if tmp_min_max_test_error[-1] > best_min_max['test_error']:
                            best_min_max['X_poison'] = min_max_results[0]
                            best_min_max['Y_poison'] = min_max_results[1]
                            best_min_max['test_error'] = tmp_min_max_test_error[-1]
                            best_min_max['test_loss'] = tmp_min_max_test_loss[-1]
                            best_min_max['train_error'] = tmp_min_max_train_error[-1]
                            best_min_max['train_loss'] = tmp_min_max_train_loss[-1]
                            best_min_max['best_tar_error'] = target_error
                            # save the poisoned points and make predictions subsequently
                            X_poison = min_max_results[0]
                            Y_poison = min_max_results[1]
                            X_total = vstack(X_train,X_poison)
                            Y_total = vstack(Y_train,Y_poison)
                            p_model = train_model(X_total,Y_total,args)
                            min_max_test_preds[kkt_epsilon] = p_model.predict(X_test)

    # ------------------------------------------------------------------
    # Save the best poisons and print the summary.
    # ------------------------------------------------------------------
    kkt_test_errors = []
    kkt_test_losses = []
    mta_test_errors_new = []
    min_max_test_errors = []
    min_max_test_losses = []

    for _eps in check_eps:
        if args.attack in ['all','mta'] and combine_train_test:
            print(_eps,mta_combined_errs[_eps])
            mta_test_errors_new.append(mta_combined_errs[_eps])

        if args.attack in ['min_max','all']:
            best_min_max = best_min_max_poisons[_eps]
            min_max_test_errors.append(best_min_max['test_error'])
            # .get(): these keys are absent when every min-max result file was missing
            min_max_test_losses.append(best_min_max.get('test_loss'))
            print("--- Min-Max Eps-{} Info ---".format(_eps))
            best_result_fname = '{}/min_max_lr_eps-{}'.format(best_result_dir,_eps)
            _dump_pickle(best_min_max_poisons, best_result_fname)
            print("Best Error Locations")
            print(best_min_max.get('best_tar_error'))

        if args.attack in ['kkt','all']:
            best_kkt = best_kkt_poisons[_eps]
            kkt_test_errors.append(best_kkt['test_error'])
            # .get(): these keys are absent when every KKT result file was missing
            kkt_test_losses.append(best_kkt.get('test_loss'))
            print("--- KKT Eps-{} Info ---".format(_eps))
            best_result_fname = '{}/kkt_eps-{}'.format(best_result_dir,_eps)
            _dump_pickle(best_kkt_poisons, best_result_fname)
            print("Best Error Locations")
            print(best_kkt.get('best_tar_error'))

    print(mta_test_errors_new)

    print("------ Test Error Summary of All Attacks ------ ")
    if args.attack in ['baseline_if','all']:
        print('Influence Attack:',baseline_IF_test_error)
    if args.attack in ['kkt','all']:
        print("KKT:",kkt_test_errors)
        print("KKT skipped target errors",kkt_skipped_tar_errs)
    if args.attack in ['mta','all']:
        if combine_train_test:
            print(mta_combined_errs)
            print("mta errors:",mta_test_errors_new)
        else:
            print("mta:",best_mta_test_error[check_eps_to_int])
            print("mta train:",best_mta_train_error[check_eps_to_int])
        print("mta skipped target errors",mta_skipped_tar_errs)

    if args.attack in ['min_max','all']:
        print("Min-Max:",min_max_test_errors)
        print("Min Max skipped target errors",min_max_skipped_tar_errs)

# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main(args)
